#include "iostream"
#include "cmath"
#include "vector"
#include "fstream"
#include "iomanip"

using namespace std;

// Cartesian position in 3-D space; base type for sources and grid nodes.
struct point_3d
{
	double x; // coordinate along the x axis
	double y; // coordinate along the y axis
	double z; // coordinate along the z axis
};

// Point source: a position (inherited from point_3d) plus two scalars.
struct source : point_3d
{
	// grid nodes no farther than this from the source are seeded as active
	double rad;
	// slowness used to convert distance from the source into arrival time
	double slow;
};

// One node of the regular grid: position (from point_3d), marching state,
// arrival times, local slowness and links to its axis-aligned neighbors.
struct grid_point : point_3d
{
	// marching state: 0 = far away, 1 = close (trial), 2 = active (accepted)
	int tag = 0;
	// best arrival time found so far; 1e+30 stands for "not reached yet"
	double time = 1e+30;
	// synthetic direct arrive time (only for error test)
	double syn_time = 1e+30;
	// local slowness; given a defined default so no member is left indeterminate
	double slow = 0.0;
	// Axis-aligned neighbors, nullptr where the node lies on the grid boundary.
	// As wired in main(): [0] = +z, [1] = -z, [2] = -x, [3] = +x, [4] = +y, [5] = -y,
	// so slots (0,1), (2,3) and (4,5) are the opposite directions of one axis.
	grid_point* neighbor[6] = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr};
};
// Owning storage for every node of the grid.
using grid_point_array = vector<grid_point>;
// Non-owning list of node pointers (used for the set of close nodes).
using ptr_grid_point_array = vector<grid_point*>;

// Sift-down step of a binary max-heap keyed on grid_point::time.
// a: heap storage; i: index of the node to sink; n: number of leading
// elements of `a` that belong to the heap.
// The node at `i` is swapped with its larger child until neither child
// inside [0, n) has a larger time.
void update_heap(ptr_grid_point_array& a, int i, int n)
{
	int parent = i;
	for (;;)
	{
		const int left = 2 * parent + 1;
		const int right = left + 1;

		// find the largest time among parent and its (existing) children
		int largest = parent;
		if (left < n && a[largest]->time < a[left]->time)
			largest = left;
		if (right < n && a[largest]->time < a[right]->time)
			largest = right;

		// heap property holds below this node: done
		if (largest == parent)
			return;

		grid_point* held = a[largest];
		a[largest] = a[parent];
		a[parent] = held;

		// continue sinking from the child we swapped into
		parent = largest;
	}
}

// In-place heap sort of the first n pointers of `a` by grid_point::time.
// A max-heap is built and then repeatedly emptied to the back, so on
// return the times are in ascending order and a[0] has the smallest time.
void heap_sort(ptr_grid_point_array& a, int n)
{
	// phase 1: heapify, sinking every candidate parent from the middle down to the root
	int root = (n - 1) / 2;
	while (root >= 0)
	{
		update_heap(a, root, n);
		--root;
	}

	// phase 2: move the current maximum behind the shrinking heap
	int last = n - 1;
	while (last > 0)
	{
		grid_point* largest = a[0];
		a[0] = a[last];
		a[last] = largest;
		update_heap(a, 0, last);
		--last;
	}
}

// Euclidean distance between two points.
double dis_point_3d(point_3d a, point_3d b)
{
	const double dx = a.x - b.x;
	const double dy = a.y - b.y;
	const double dz = a.z - b.z;
	return sqrt(dx*dx + dy*dy + dz*dz);
}

// Recompute the trial arrival time of one node from its active (tag == 2)
// neighbors and keep it only if it is smaller than the time already stored.
//
// point_ptr: node to update; its `time` is lowered in place, never raised.
// delta_h:   grid spacing.
//
// Scheme: upwind finite-difference form of |grad T| = slow. Per axis the
// first-order term is (T - T1)^2 / h^2; when the next node in the same
// direction is active too, the second-order one-sided term
// (3T - 4*T1 + T2)^2 / (4 h^2) is used instead. Everything is multiplied by
// 4 h^2, giving a*T^2 + b*T + c = 0, of which the larger root is taken.
//
// Each axis contributes through ONE direction only: the active neighbor with
// the smaller time. This relies on neighbor slots (0,1), (2,3), (4,5) being
// the opposite directions of an axis. Summing both sides of an axis (as an
// earlier version did) is not an upwind discretisation and underestimates the
// time when both sides are active.
void local_update(grid_point* point_ptr, double delta_h)
{
	double a = 0.0;
	double b = 0.0;
	double c = -4.0 * delta_h * delta_h * point_ptr->slow * point_ptr->slow;

	for (int axis = 0; axis < 3; axis++)
	{
		// choose the upwind direction of this axis: the active neighbor that arrived first
		int dir = -1;
		for (int side = 0; side < 2; side++)
		{
			const int cand = 2 * axis + side;
			const grid_point* nb = point_ptr->neighbor[cand];
			if (nb == nullptr || nb->tag != 2)
				continue;
			if (dir < 0 || nb->time < point_ptr->neighbor[dir]->time)
				dir = cand;
		}
		if (dir < 0)
			continue; // no active neighbor on this axis: it does not contribute

		const grid_point* first = point_ptr->neighbor[dir];
		const double t1 = first->time;

		// first-order contribution: 4*(T - t1)^2
		a += 4.0;
		b += -8.0*t1;
		c += 4.0*t1*t1;

		// upgrade to (3T - 4*t1 + t2)^2 when the node behind `first` is active as well
		const grid_point* second = first->neighbor[dir];
		if (second != nullptr && second->tag == 2)
		{
			const double t2 = second->time;
			a += 5.0;
			b += -16.0*t1 + 6.0*t2;
			c += 12.0*t1*t1 - 8.0*t1*t2 + t2*t2;
		}
	}

	// without any active neighbor there is no equation to solve
	if (a == 0.0)
		return;

	// The discriminant can be negative when the neighbor times are mutually
	// inconsistent; the stored time is then left untouched (also for NaN).
	const double delta = b*b - 4.0*a*c;
	if (!(delta >= 0.0))
		return;

	const double trial = 0.5*(std::sqrt(delta) - b)/a;
	if (trial < point_ptr->time)
		point_ptr->time = trial;
}

// Overwrite the slowness of every node that lies inside an axis-aligned box.
//
// grid_recall: grid whose nodes are modified in place.
// ab_slow:     slowness assigned to the nodes inside the box.
// xmin..zmax:  box limits; both ends of each interval are inclusive.
void abnormal_slowness(grid_point_array& grid_recall, double ab_slow,double xmin, double xmax, double ymin, double ymax, double zmin, double zmax)
{
	// range-for avoids comparing a signed index against the unsigned size()
	for (grid_point& node : grid_recall)
	{
		const bool inside =
			node.x >= xmin && node.x <= xmax &&
			node.y >= ymin && node.y <= ymax &&
			node.z >= zmin && node.z <= zmax;
		if (inside)
		{
			node.slow = ab_slow;
		}
	}
}

int main(int argc, char const *argv[])
{
	//set grid parameters
	int xnum = 101;
	int ynum = 51;
	int znum = 51;
	double xmin = 0;
	double ymin = 0;
	double zmin = 0;
	double dh = 10;

	//set source parameters
	source init_source;
	init_source.x = 50;
	init_source.y = 250;
	init_source.z = 250;
	init_source.rad = 30;
	init_source.slow = 1.0;

	//output name
	char ofilename[1024] = "out.msh";

	//initialize grid
	grid_point_array grid_3d;
	grid_3d.resize(xnum*ynum*znum);
	//down-left corner to up-right corner
	for (int k = 0; k < znum; k++)
	{
		for (int i = 0; i < ynum; i++)
		{
			for (int j = 0; j < xnum; j++)
			{

				grid_3d[i*xnum+j+k*xnum*ynum].x = xmin + dh*j;
				grid_3d[i*xnum+j+k*xnum*ynum].y = ymin + dh*i;
				grid_3d[i*xnum+j+k*xnum*ynum].z = zmin + dh*k;
				grid_3d[i*xnum+j+k*xnum*ynum].slow = init_source.slow;

				if (k <= znum-2) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[0] = &grid_3d[i*xnum+j+(k+1)*xnum*ynum];
				if (k >= 1) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[1] = &grid_3d[i*xnum+j+(k-1)*xnum*ynum];
				if (j >= 1) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[2] = &grid_3d[i*xnum+j-1+k*xnum*ynum]; //left
				if (j <= xnum-2) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[3] = &grid_3d[i*xnum+j+1+k*xnum*ynum]; //right
				if (i <= ynum-2) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[4] = &grid_3d[(i+1)*xnum+j+k*xnum*ynum]; //up
				if (i >= 1) grid_3d[i*xnum+j+k*xnum*ynum].neighbor[5] = &grid_3d[(i-1)*xnum+j+k*xnum*ynum]; //down

				//calculate synthetic direct arrive time
				grid_3d[i*xnum+j+k*xnum*ynum].syn_time = dis_point_3d(grid_3d[i*xnum+j+k*xnum*ynum], init_source) * init_source.slow;
			}
		}
	}

	//add abnormal slowness here
	//abnormal_slowness(grid_3d,1e+30,200,250,0,500,0,250);
	//abnormal_slowness(grid_3d,1e+30,500,550,0,500,250,500);
	//abnormal_slowness(grid_3d,1e+30,800,850,0,500,0,250);

	ptr_grid_point_array close_node_ptr;
	//initialize source nodes and close nodes;
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		if (dis_point_3d(grid_3d[i],init_source) <= init_source.rad)
		{
			grid_3d[i].tag = 2;
			grid_3d[i].time = dis_point_3d(grid_3d[i],init_source) * init_source.slow;
		}
	}

	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		if (grid_3d[i].tag == 2)
		{
			for (int j = 0; j < 6; j++)
			{
				if (grid_3d[i].neighbor[j] != NULL && grid_3d[i].neighbor[j]->tag == 0)
				{
					grid_3d[i].neighbor[j]->tag = 1;
					close_node_ptr.push_back(grid_3d[i].neighbor[j]);
				}
			}
		}
	}

	//calculate trial time for all close nodes
	for (int i = 0; i < close_node_ptr.size(); i++)
	{
		local_update(close_node_ptr[i], dh);
	}
	
	//marching forward and updating the close nodes set
	while (!close_node_ptr.empty())
	{
		// heap sort close nodes pointers to put the node first that has smallest time
		heap_sort(close_node_ptr,close_node_ptr.size());

		//change the first node's tag to 2 and update it's neighbor's time if it is not active
		close_node_ptr[0]->tag = 2;
		for (int i = 0; i < 6; i++)
		{
			if (close_node_ptr[0]->neighbor[i] != NULL && close_node_ptr[0]->neighbor[i]->tag == 0)
			{
				close_node_ptr[0]->neighbor[i]->tag = 1;
				local_update(close_node_ptr[0]->neighbor[i], dh);
				close_node_ptr.push_back(close_node_ptr[0]->neighbor[i]);
			}
			else if (close_node_ptr[0]->neighbor[i] != NULL && close_node_ptr[0]->neighbor[i]->tag == 1)
			{
				local_update(close_node_ptr[0]->neighbor[i], dh);
			}
		}
		close_node_ptr.erase(close_node_ptr.begin());
	}
	
	//output Gmsh(.msh) file
	ofstream outfile;
	outfile.open(ofilename);
	if (!outfile)
	{
		cerr << "file open error: " << ofilename << endl;
		return -1;
	}

	outfile<<"$MeshFormat"<<endl<<"2.2 0 8"<<endl<<"$EndMeshFormat"<<endl<<"$Nodes"<<endl<<xnum*ynum*znum<<endl;
	// set the first vertex index to 1 to avoid a display bug in Gmsh
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		outfile << i + 1 << " " << setprecision(16) << grid_3d[i].x << " " << grid_3d[i].y << " " << grid_3d[i].z << endl;
	}
	outfile<<"$EndNodes"<<endl;

	outfile<<"$Elements"<<endl<< (xnum-1)*(ynum-1)*(znum-1) <<endl;
	for (int k = 0; k < znum-1; k++)
	{
		for (int i = 0; i < ynum-1; i++)
		{
			for (int j = 0; j < xnum-1; j++)
			{
				outfile << 1 + j + i*(xnum-1) + k*(xnum-1)*(ynum-1) <<" 5 1 0 ";
				outfile << 1 + j + i*xnum + k*xnum*ynum << " "
						<< 1 + j + i*xnum + k*xnum*ynum + 1 << " "
						<< 1 + j + (i+1)*xnum + k*xnum*ynum + 1 << " "
						<< 1 + j + (i+1)*xnum + k*xnum*ynum << " "
						<< 1 + j + i*xnum + (k+1)*xnum*ynum << " "
						<< 1 + j + i*xnum + (k+1)*xnum*ynum + 1 << " "
						<< 1 + j + (i+1)*xnum + (k+1)*xnum*ynum + 1 << " "
						<< 1 + j + (i+1)*xnum + (k+1)*xnum*ynum << endl;
			}
		}
	}
	outfile << "$EndElements"<< endl;

	int tmp_node_num = 0;
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		if (grid_3d[i].time < 1e+10)
			tmp_node_num++;
	}
	outfile <<"$NodeData"<< endl;
	outfile<<1<<endl<<"\"fmm arrive time (ms)\""<<endl<<1<<endl<<0.0<<endl
	<<3<<endl<<0<<endl<<1<<endl<<tmp_node_num<<endl;
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		if (grid_3d[i].time < 1e+10)
		{
			outfile << i + 1 << " " << setprecision(16) << grid_3d[i].time << endl;
		}
	}
	outfile<<"$EndNodeData"<< endl;

	outfile <<"$NodeData"<< endl;
	outfile<<1<<endl<<"\"fmm - synthetic time (ms)\""<<endl<<1<<endl<<0.0<<endl
	<<3<<endl<<0<<endl<<1<<endl<<tmp_node_num<<endl;
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		if (grid_3d[i].time < 1e+10)
		{
			outfile << i + 1 << " " << setprecision(16) << grid_3d[i].time - grid_3d[i].syn_time << endl;
		}
	}
	outfile<<"$EndNodeData"<< endl;

	outfile <<"$NodeData"<< endl;
	outfile<<1<<endl<<"\"node tag\""<<endl<<1<<endl<<0.0<<endl
	<<3<<endl<<0<<endl<<1<<endl<<xnum*ynum*znum<<endl;
	for (int i = 0; i < xnum*ynum*znum; i++)
	{
		outfile << i + 1 << " " << setprecision(16) << grid_3d[i].tag << endl;
	}
	outfile<<"$EndNodeData"<< endl;

	outfile.close();
	return 0;
}